diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py index 2b44156..32fb3a2 100644 --- a/vall_e/emb/similar.py +++ b/vall_e/emb/similar.py @@ -4,7 +4,7 @@ """ import os -import json +import orjson as json import argparse import torch import torchaudio @@ -53,6 +53,7 @@ def process( verbose=False, metadata_path=None, + top_k=8, trim_duration=0, min_duration=0, @@ -167,7 +168,7 @@ def process( # to-do: support just querying for list of similar to cram into JSON metadata if verbose: for filename, embedding in features.items(): - D, I = index.search(embedding.unsqueeze(0).cpu(), k=2) + D, I = index.search(embedding.unsqueeze(0).cpu(), k=top_k+1) sim = list(I[0][1:]) print(f'{filename}: {sim}') """ @@ -180,10 +181,18 @@ def process( embeddings = torch.stack( list( features.values() ) ) sorted_similarities = {} - for filename in tqdm(keys, desc=f"Computing similarities: {speaker_path.name}"): + for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"): embedding = features[filename].unsqueeze(0) - similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist() + + similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1) + # set current index to -inf + similarities[index] = float("-inf") + similarities = torch.topk(similarities, k=top_k, largest=True, sorted=True).indices.tolist() + # similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist() + sorted_similarities[filename] = similarities + + # sorting is slow, don't bother #sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True) return sorted_similarities @@ -196,17 +205,19 @@ def main(): parser.add_argument("--yaml", type=Path) parser.add_argument("--text", action="store_true") + # dropped, because this might mess with the indices to map to """ parser.add_argument("--trim-duration", type=float, default=3.0) parser.add_argument("--min-duration", type=float, default=0) parser.add_argument("--max-duration", type=float, default=0) """ parser.add_argument("--storage-backend", type=str, default="slop") + parser.add_argument("--top-k", type=int, default=8) parser.add_argument("--audio-backend", type=str, default="encodec") - parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--dtype", type=str, default="float16") parser.add_argument("--amp", action="store_true") - parser.add_argument("--device", type=str, default="cpu") # unironically faster + parser.add_argument("--device", type=str, default="cuda") args = parser.parse_args() @@ -226,6 +237,7 @@ def main(): speaker_path=cfg.data_dir / speaker_name, yaml=args.yaml, text=args.text, + top_k=args.top_k, #trim_duration=args.trim_duration, #min_duration=args.min_duration, #max_duration=args.max_duration, @@ -252,7 +264,7 @@ def main(): metadata[filename]["similar"] = sim - with open(str(metadata_path), "w", encoding="utf-8") as f: + with open(str(metadata_path), "wb") as f: f.write( json.dumps( metadata ) ) #f.write( truncate_json( json.dumps( metadata ) ) ) @@ -273,6 +285,7 @@ def main(): speaker_path=args.input_speaker, yaml=args.yaml, text=args.text, + top_k=args.top_k, #trim_duration=args.trim_duration, #min_duration=args.min_duration,