invalidate a path if loading via metadata and the entry is not in the HDF5 (to avoid reparsing my metadata, since I'm currently using a partial copy of my dataset)

mrq 2025-02-10 14:43:15 -06:00
parent 075ffef68a
commit b3f9b76fd9
3 changed files with 101 additions and 19 deletions

View File

@@ -762,8 +762,13 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
        phones = entry['phones'] if "phones" in entry else 0
        duration = entry['duration'] if "duration" in entry else 0
-       # add to duration bucket
        k = key(id, entry)
+
+       # double check if in HDF5
+       if cfg.dataset.use_hdf5 and k not in cfg.hdf5:
+           return False
+
+       # add to duration bucket
        if type not in _durations_map:
            _durations_map[type] = {}
        _durations_map[type][k] = duration
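The guard above treats `cfg.hdf5` as an open HDF5 file, where the `in` operator checks whether a group or dataset path exists without reading it. A minimal standalone sketch of the same idea with plain h5py (function and variable names here are hypothetical, not the project's):

import h5py

def filter_metadata(metadata: dict, hdf5_path: str) -> dict:
    """Keep only metadata entries whose key exists as a path in the (possibly partial) HDF5 file."""
    kept = {}
    with h5py.File(hdf5_path, "r") as h5:
        for key, entry in metadata.items():
            # h5py's `in` tests for a group/dataset at that path without loading it
            if key not in h5:
                continue
            kept[key] = entry
    return kept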

View File

@@ -48,18 +48,19 @@ tts = None
# this is for computing SIM-O, but can probably technically be used for scoring similar utterances
@cache
-def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large'):
+def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large', finetune=False):
    from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large')

-   finetune_path = Path("./data/models/wavlm_large_finetune.pth")
-   if not finetune_path.exists():
-       download_model(finetune_path)
+   if finetune:
+       finetune_path = Path("./data/models/wavlm_large_finetune.pth")
+       if not finetune_path.exists():
+           download_model(finetune_path)

-   state_dict = torch.load( finetune_path )
-   state_dict = state_dict['model']
-   del state_dict['loss_calculator.projection.weight']
-   model.load_state_dict( state_dict )
+       state_dict = torch.load( finetune_path )
+       state_dict = state_dict['model']
+       del state_dict['loss_calculator.projection.weight']
+       model.load_state_dict( state_dict )

    model = model.to(device=device, dtype=coerce_dtype(dtype))
    model = model.eval()
@@ -88,6 +89,7 @@ def speaker_similarity_embedding(
    audio,
    **model_kwargs,
):
+   model_kwargs["finetune"] = True
    device = model_kwargs.get("device", "cuda")
    dtype = model_kwargs.get("dtype", "float16")
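One reason the new `finetune` argument is worth threading through a `@cache`-decorated loader: `functools.cache` keys on the full argument tuple, so the base and finetuned models get separate cache entries, and `speaker_similarity_embedding` forcing `finetune=True` always resolves to the finetuned one. A tiny illustration with a stand-in loader (not the repo's):

from functools import cache

@cache
def load_model(name: str, finetune: bool = False) -> str:
    # stand-in for the real loader; returns a tag instead of a torch module
    return f"{name}{'+finetune' if finetune else ''}"

assert load_model("wavlm-large") is load_model("wavlm-large")                  # same cached object
assert load_model("wavlm-large") != load_model("wavlm-large", finetune=True)   # separate cache entry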
@@ -299,6 +301,80 @@ def batch_similar_utterances(

    return sorted_similarities

+"""
+# (Attempts to) group speakers based on top-k cosine similarities, by pooling similar utterances together
+# It sort of works, but the WavLM finetuned for speaker similarities leaves some false positives without decent threshold values
+"""
+def sort_similarities(
+    path,
+    out_path=None,
+    threshold=0.8,
+    orphan_threshold=0.6,
+):
+    if not out_path:
+        out_path = path.parent / "speakers.json"
+
+    # load the precomputed { filename: [ (index, name, score), ... ] } similarity metadata
+    metadata = json_read( path )
+
+    orphans = []
+    speakers = []
+
+    for filename, similarities in metadata.items():
+        target = None
+        # find any existing buckets
+        for i, pool in enumerate(speakers):
+            for (idx, name, score) in similarities:
+                if score and score < threshold:
+                    continue
+                if name in pool:
+                    target = i
+                    break
+            if target is not None:
+                break
+
+        # not found, create new bucket
+        if target is None:
+            pool = [ name for (idx, name, score) in similarities if (not score or score > threshold) ]
+            if filename not in pool:
+                pool.append(filename)
+            # orphan, check later
+            if len(pool) == 1:
+                orphans += pool
+            else:
+                speakers.append(pool)
+            continue
+
+        # insert entries into the pool
+        if filename not in speakers[target]:
+            speakers[target].append(filename)
+        for (idx, name, score) in similarities:
+            if score and score < threshold:
+                continue
+            if name not in speakers[target]:
+                speakers[target].append(name)
+
+    # shove orphans into the best scoring pool
+    for filename in orphans:
+        target = None
+        for (idx, name, score) in metadata[filename]:
+            if score and score < orphan_threshold:
+                continue
+            for i, pool in enumerate(speakers):
+                if name in pool:
+                    target = i
+                    break
+            if target is not None:
+                break
+        if target is None:
+            continue
+        speakers[target].append(filename)
+
+    json_write( speakers, out_path )

def main():
    parser = argparse.ArgumentParser()
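For reference, `sort_similarities` expects a mapping of filename to top-k `(index, name, score)` triples. A rough sketch of how such a mapping could be built from per-utterance embeddings via cosine similarity, under assumed shapes and names (the repo's actual `batch_similar_utterances` is not reproduced here):

import torch

def top_k_similarities(embeddings: dict, k: int = 8) -> dict:
    names = list(embeddings.keys())
    # stack and L2-normalize so a single matrix product yields cosine similarities
    feats = torch.nn.functional.normalize(torch.stack([embeddings[n] for n in names]), dim=-1)
    scores = feats @ feats.T
    out = {}
    for i, name in enumerate(names):
        row = scores[i].clone()
        row[i] = -1.0  # mask self-similarity
        topk = torch.topk(row, k=min(k, len(names) - 1))
        out[name] = [ (int(j), names[int(j)], float(s)) for s, j in zip(topk.values, topk.indices) ]
    return out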
@@ -423,17 +499,17 @@ def main():
            return_features=True,
        )

-       features_json = {}
-       for k, v in features.items():
-           features_json[k] = [ x.item() for x in v ]
        if args.out_path is not None:
+           features_json = {}
+           for k, v in features.items():
+               features_json[k] = [ x.item() for x in v ]
            json_write( similarities, args.out_path / "similarities.json" )
            json_write( features_json, args.out_path / "embeddings.json" )
-       # and print
-       for filename, sim in similarities.items():
-           print(f'{filename}: {sim}')
+       else:
+           # and print
+           for filename, sim in similarities.items():
+               print(f'{filename}: {sim}')
    else:
        raise Exception("!")
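A side note on the `[ x.item() for x in v ]` conversion that moved inside the `out_path` branch: torch tensors are not JSON-serializable, so each embedding element is turned into a plain Python float before `json_write`. For example:

import json
import torch

embedding = torch.randn(4)
# json.dumps({"utt": embedding})  # would raise TypeError: Tensor is not JSON serializable
print(json.dumps({"utt": [ x.item() for x in embedding ]}))  # plain floats serialize fine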

View File

@@ -25,6 +25,7 @@ from tqdm.auto import tqdm
from pathlib import Path

from ..utils import coerce_dtype
+from ..utils.io import json_read, json_write

def pad(num, zeroes):
    return str(num).zfill(zeroes+1)
@@ -308,7 +309,7 @@ def transcribe_batch(
            outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json')

            if outpath.exists():
-               metadata = json.loads(open(outpath, 'r', encoding='utf-8').read())
+               metadata = json_read( outpath )
            else:
                os.makedirs(f'./{output_metadata}/{dataset_name}/{speaker_id}/', exist_ok=True)
                metadata = {}
@@ -327,7 +328,7 @@ def transcribe_batch(
                metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
-               open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
+               json_write( metadata, outpath )

def main():
    parser = argparse.ArgumentParser()
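The last two hunks replace raw `open` plus `json.loads`/`json.dumps` calls with the shared `json_read`/`json_write` helpers from `..utils.io`. Their implementation is not part of this diff; presumably they are thin wrappers along these lines (a sketch, not the actual helpers):

import json
from pathlib import Path

def json_read(path):
    # read and parse a JSON file as UTF-8
    return json.loads(Path(path).read_text(encoding="utf-8"))

def json_write(data, path):
    # serialize and write JSON as UTF-8
    Path(path).write_text(json.dumps(data), encoding="utf-8")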