diff --git a/vall_e/data.py b/vall_e/data.py index e7b58fe..05a2c5d 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -762,8 +762,13 @@ def _load_paths_from_metadata(group_name, type="training", validate=False): phones = entry['phones'] if "phones" in entry else 0 duration = entry['duration'] if "duration" in entry else 0 - # add to duration bucket k = key(id, entry) + + # double check if in HDF5 + if cfg.dataset.use_hdf5 and k not in cfg.hdf5: + return False + + # add to duration bucket if type not in _durations_map: _durations_map[type] = {} _durations_map[type][k] = duration diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 942b679..c56aeb0 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -48,18 +48,19 @@ tts = None # this is for computing SIM-O, but can probably technically be used for scoring similar utterances @cache -def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large'): +def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large', finetune=False): from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large') - finetune_path = Path("./data/models/wavlm_large_finetune.pth") - if not finetune_path.exists(): - download_model(finetune_path) + if finetune: + finetune_path = Path("./data/models/wavlm_large_finetune.pth") + if not finetune_path.exists(): + download_model(finetune_path) - state_dict = torch.load( finetune_path ) - state_dict = state_dict['model'] - del state_dict['loss_calculator.projection.weight'] - model.load_state_dict( state_dict ) + state_dict = torch.load( finetune_path ) + state_dict = state_dict['model'] + del state_dict['loss_calculator.projection.weight'] + model.load_state_dict( state_dict ) model = model.to(device=device, dtype=coerce_dtype(dtype)) model = model.eval() @@ -88,6 +89,7 @@ def speaker_similarity_embedding( audio, **model_kwargs, ): + model_kwargs["finetune"] = True device = 
def _accepted_names(similarities, threshold):
    """Return the neighbour names whose score passes ``threshold``.

    Each entry is unpacked as ``(idx, name, score)``. A falsy score
    (``None`` or ``0``) is treated as "no score available" and is accepted.
    """
    return [
        name
        for (_idx, name, score) in similarities
        if not score or score >= threshold
    ]


def _find_pool(names, members):
    """Return the index of the first pool that contains any of ``names``.

    Returns ``None`` when nothing matches. ``None`` (not ``False``) is the
    sentinel because ``0 == False`` in Python, and 0 is a valid pool index.
    """
    for index, pool in enumerate(members):
        if any(name in pool for name in names):
            return index
    return None


def _pool_speakers(metadata, threshold, orphan_threshold):
    """Group utterances into speaker pools from their similarity lists.

    ``metadata`` maps a filename to an iterable of ``(idx, name, score)``
    entries. Returns a list of pools, each pool being a list of names.
    """
    pools = []    # ordered member lists; this is what gets returned
    members = []  # parallel sets, only used for fast membership tests
    orphans = []

    def add(index, name):
        # keep each name at most once per pool
        if name not in members[index]:
            members[index].add(name)
            pools[index].append(name)

    for filename, similarities in metadata.items():
        names = _accepted_names(similarities, threshold)
        target = _find_pool(names, members)

        # not found, create new bucket
        if target is None:
            pool = list(names)
            if filename not in pool:
                pool.append(filename)

            # a pool holding only the file itself is an orphan, check later
            if len(pool) == 1:
                orphans.append(filename)
            else:
                pools.append(pool)
                members.append(set(pool))
            continue

        # insert the file and its accepted neighbours into the matched pool
        add(target, filename)
        for name in names:
            add(target, name)

    # shove orphans to best scoring pool
    for filename in orphans:
        # an orphan may have been absorbed as a neighbour of a later file
        if any(filename in pool for pool in members):
            continue

        # NOTE(review): assumes each similarity list is ordered best-first,
        # so the first neighbour that belongs to a pool is the best match
        for name in _accepted_names(metadata[filename], orphan_threshold):
            target = _find_pool([name], members)
            if target is not None:
                add(target, filename)
                break

    return pools


def sort_similarities(
    path,
    out_path=None,
    threshold=0.8,
    orphan_threshold=0.6,
):
    """(Attempt to) group speakers by pooling similar utterances.

    Utterances are pooled by their top-k cosine similarities. It sort of
    works, but the WavLM finetuned for speaker similarities leaves some
    false positives without decent threshold values.

    Args:
        path: JSON file mapping each filename to a list of
            ``(idx, name, score)`` entries. NOTE(review): presumably the
            ``similarities.json`` written by ``main`` -- confirm.
        out_path: Where to write the pools; defaults to ``speakers.json``
            next to ``path``.
        threshold: Minimum score for a neighbour to join or seed a pool.
        orphan_threshold: Looser minimum score used when attaching files
            that ended up in no pool.
    """
    import json

    if not out_path:
        out_path = path.parent / "speakers.json"

    # the similarity table has to be loaded from `path` before grouping
    with open(path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    speakers = _pool_speakers(metadata, threshold, orphan_threshold)

    json_write( speakers, out_path )
features_json = {} - for k, v in features.items(): - features_json[k] = [ x.item() for x in v ] - if args.out_path is not None: + features_json = {} + for k, v in features.items(): + features_json[k] = [ x.item() for x in v ] + json_write( similarities, args.out_path / "similarities.json" ) json_write( features_json, args.out_path / "embeddings.json" ) - - # and print - for filename, sim in similarities.items(): - print(f'{filename}: {sim}') + else: + # and print + for filename, sim in similarities.items(): + print(f'{filename}: {sim}') else: raise Exception("!") diff --git a/vall_e/emb/transcribe.py b/vall_e/emb/transcribe.py index 6f3bbe2..e496eb0 100644 --- a/vall_e/emb/transcribe.py +++ b/vall_e/emb/transcribe.py @@ -25,6 +25,7 @@ from tqdm.auto import tqdm from pathlib import Path from ..utils import coerce_dtype +from ..utils.io import json_read, json_write def pad(num, zeroes): return str(num).zfill(zeroes+1) @@ -308,7 +309,7 @@ def transcribe_batch( outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json') if outpath.exists(): - metadata = json.loads(open(outpath, 'r', encoding='utf-8').read()) + metadata = json_read( outpath ) else: os.makedirs(f'./{output_metadata}/{dataset_name}/{speaker_id}/', exist_ok=True) metadata = {} @@ -327,7 +328,7 @@ def transcribe_batch( metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype ) - open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata)) + json_write( metadata, outpath ) def main(): parser = argparse.ArgumentParser()