invalidate a path if loading via metadata and the entry is not in the HDF5 (to avoid reparsing my metadata, since I'm using a partial copy of my dataset at the moment)
parent 075ffef68a · commit b3f9b76fd9
```diff
@@ -762,8 +762,13 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 		phones = entry['phones'] if "phones" in entry else 0
 		duration = entry['duration'] if "duration" in entry else 0
 
-		# add to duration bucket
 		k = key(id, entry)
+
+		# double check if in HDF5
+		if cfg.dataset.use_hdf5 and k not in cfg.hdf5:
+			return False
+
+		# add to duration bucket
 		if type not in _durations_map:
 			_durations_map[type] = {}
 		_durations_map[type][k] = duration
```
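The new guard relies on `cfg.hdf5` being an open handle to the dataset's HDF5 file, so the check is a plain key-membership test. A rough standalone sketch of the same idea (the `filter_paths` helper and the `"{group}/{id}"` key layout are assumptions for illustration, not taken from this diff):

```python
# Sketch: drop metadata entries whose key is missing from a (possibly partial) HDF5 dataset.
import h5py

def filter_paths(metadata: dict, hdf5_path: str, group_name: str) -> dict:
	kept = {}
	with h5py.File(hdf5_path, "r") as hf:
		for id, entry in metadata.items():
			k = f"{group_name}/{id}"  # assumed key layout
			if k not in hf:  # same membership test as `k not in cfg.hdf5`
				continue
			kept[id] = entry
	return kept
```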
```diff
@@ -48,18 +48,19 @@ tts = None
 
 # this is for computing SIM-O, but can probably technically be used for scoring similar utterances
 @cache
-def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large'):
+def _load_sim_model(device="cuda", dtype="float16", model_name='microsoft/wavlm-large', finetune=False):
 	from ..utils.ext.ecapa_tdnn import ECAPA_TDNN_SMALL
 	model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large')
 
-	finetune_path = Path("./data/models/wavlm_large_finetune.pth")
-	if not finetune_path.exists():
-		download_model(finetune_path)
+	if finetune:
+		finetune_path = Path("./data/models/wavlm_large_finetune.pth")
+		if not finetune_path.exists():
+			download_model(finetune_path)
 
-	state_dict = torch.load( finetune_path )
-	state_dict = state_dict['model']
-	del state_dict['loss_calculator.projection.weight']
-	model.load_state_dict( state_dict )
+		state_dict = torch.load( finetune_path )
+		state_dict = state_dict['model']
+		del state_dict['loss_calculator.projection.weight']
+		model.load_state_dict( state_dict )
 
 	model = model.to(device=device, dtype=coerce_dtype(dtype))
 	model = model.eval()
```
```diff
@@ -88,6 +89,7 @@ def speaker_similarity_embedding(
 	audio,
 	**model_kwargs,
 ):
+	model_kwargs["finetune"] = True
 	device = model_kwargs.get("device", "cuda")
 	dtype = model_kwargs.get("dtype", "float16")
 
```
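Since `speaker_similarity_embedding` now forces `model_kwargs["finetune"] = True`, only the speaker-similarity path pulls in the finetuned checkpoint. A hypothetical call site (not part of this diff) would look like:

```python
# Speaker-similarity scoring: WavLM plus the finetuned speaker-verification weights.
sim_model = _load_sim_model(device="cuda", dtype="float16", finetune=True)

# Other feature extraction can keep the base wavlm_large weights.
base_model = _load_sim_model(device="cuda", dtype="float16", finetune=False)
```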
```diff
@@ -299,6 +301,80 @@ def batch_similar_utterances(
 
 	return sorted_similarities
 
+"""
+# (Attempts to) group speakers based on top-k cosine similarities, by pooling together similar utterances together
+# It sort of works, but the WavLM finetuned for speaker similarities leaves some false positives without decent threshold values
+"""
+def sort_similarities(
+	path,
+	out_path=None,
+	threshold=0.8,
+	orphan_threshold=0.6,
+):
+	if not out_path:
+		out_path = path.parent / "speakers.json"
+
+	orphans = []
+	speakers = []
+
+	for filename, similarities in metadata.items():
+		target = False
+
+		# find any existing buckets
+		for i, pool in enumerate(speakers):
+			for (idx, name, score) in similarities:
+				if score and score < threshold:
+					continue
+				if name in pool:
+					target = i
+					break
+
+			if target != False:
+				break
+		# not found, create new bucket
+		if target == False:
+			pool = [ name for (idx, name, score) in similarities if (not score or score > threshold) ]
+			if filename not in pool:
+				pool.append(filename)
+
+			# orphan, check later
+			if len(pool) == 1:
+				orphans += pool
+			else:
+				speakers.append(pool)
+			continue
+
+		# insert entries into pool
+		if filename not in speakers[target]:
+			speakers[target].append(filename)
+
+		for (idx, name, score) in similarities:
+			if score and score < threshold:
+				continue
+			if name not in speakers[target]:
+				speakers[target].append(name)
+
+	# shove orphans to best scoring pool
+	for filename in orphans:
+		target = False
+		for (idx, name, score) in metadata[filename]:
+			if score and score < orphan_threshold:
+				continue
+			for i, pool in enumerate(speakers):
+				if name in pool:
+					target = i
+					break
+			if target != False:
+				continue
+
+		if target == False:
+			continue
+
+		speakers[target].append(filename)
+
+
+	json_write( speakers, out_path )
+
 def main():
 	parser = argparse.ArgumentParser()
 
```
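Stripped of the orphan-reattachment pass, the grouping in `sort_similarities` amounts to threshold-gated pooling of each file's top-k neighbours. A simplified restatement for clarity (a hypothetical `pool_speakers`, not the committed function):

```python
def pool_speakers(metadata, threshold=0.8):
	speakers, orphans = [], []
	for filename, similarities in metadata.items():
		# names whose cosine similarity clears the threshold (a falsy score is treated as a pass)
		matches = [ name for (idx, name, score) in similarities if not score or score >= threshold ]
		# merge into the first existing bucket that already holds a confident match
		target = next((pool for pool in speakers if any(n in pool for n in matches)), None)
		if target is None:
			pool = list(dict.fromkeys(matches + [filename]))
			if len(pool) == 1:
				orphans.append(filename)  # nothing cleared the threshold; handled separately
			else:
				speakers.append(pool)
		else:
			for n in matches + [filename]:
				if n not in target:
					target.append(n)
	return speakers, orphans
```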
```diff
@@ -423,17 +499,17 @@ def main():
 			return_features=True,
 		)
 
-		features_json = {}
-		for k, v in features.items():
-			features_json[k] = [ x.item() for x in v ]
-
 		if args.out_path is not None:
+			features_json = {}
+			for k, v in features.items():
+				features_json[k] = [ x.item() for x in v ]
+
 			json_write( similarities, args.out_path / "similarities.json" )
 			json_write( features_json, args.out_path / "embeddings.json" )
-
-		# and print
-		for filename, sim in similarities.items():
-			print(f'{filename}: {sim}')
+		else:
+			# and print
+			for filename, sim in similarities.items():
+				print(f'{filename}: {sim}')
 	else:
 		raise Exception("!")
 
```
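With the reshuffled `main()`, setting `args.out_path` writes `similarities.json` and `embeddings.json` instead of printing. A hypothetical sanity check of those artifacts (the directory and JSON structure are assumptions, not shown in the diff):

```python
import json
from pathlib import Path

out_path = Path("./training/metadata")  # placeholder for whatever out_path was used
similarities = json.loads((out_path / "similarities.json").read_text(encoding="utf-8"))
embeddings = json.loads((out_path / "embeddings.json").read_text(encoding="utf-8"))

# similarities: filename -> top-k matches; embeddings: filename -> list of floats from features_json
print(f"{len(similarities)} files scored, {len(embeddings)} embeddings dumped")
```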
```diff
@@ -25,6 +25,7 @@ from tqdm.auto import tqdm
 from pathlib import Path
 
 from ..utils import coerce_dtype
+from ..utils.io import json_read, json_write
 
 def pad(num, zeroes):
 	return str(num).zfill(zeroes+1)
```
```diff
@@ -308,7 +309,7 @@ def transcribe_batch(
 		outpath = Path(f'./{output_metadata}/{dataset_name}/{speaker_id}/whisper.json')
 
 		if outpath.exists():
-			metadata = json.loads(open(outpath, 'r', encoding='utf-8').read())
+			metadata = json_read( outpath )
 		else:
 			os.makedirs(f'./{output_metadata}/{dataset_name}/{speaker_id}/', exist_ok=True)
 			metadata = {}
```
```diff
@@ -327,7 +328,7 @@ def transcribe_batch(
 
 		metadata[filename] = transcribe( inpath, model_name=model_name, diarize=diarize, device=device, dtype=dtype )
 
-		open(outpath, 'w', encoding='utf-8').write(json.dumps(metadata))
+		json_write( metadata, outpath )
 
 def main():
 	parser = argparse.ArgumentParser()
```
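The transcription script now routes all metadata I/O through `json_read`/`json_write` from `..utils.io`. That module isn't shown in this commit; a minimal stand-in with the same call shape (an assumption, not the repo's actual implementation) would be:

```python
import json
from pathlib import Path

def json_read(path, default=None):
	path = Path(path)
	if not path.exists():
		return default
	return json.loads(path.read_text(encoding="utf-8"))

def json_write(data, path):
	Path(path).write_text(json.dumps(data), encoding="utf-8")
```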