This commit is contained in:
mrq 2024-09-17 22:44:36 -05:00
parent f00283440c
commit 6ceed866b5

View File

@ -4,7 +4,7 @@
""" """
import os import os
import json import orjson as json
import argparse import argparse
import torch import torch
import torchaudio import torchaudio
@ -53,6 +53,7 @@ def process(
verbose=False, verbose=False,
metadata_path=None, metadata_path=None,
top_k=8,
trim_duration=0, trim_duration=0,
min_duration=0, min_duration=0,
@ -167,7 +168,7 @@ def process(
# to-do: support just querying for list of similar to cram into JSON metadata # to-do: support just querying for list of similar to cram into JSON metadata
if verbose: if verbose:
for filename, embedding in features.items(): for filename, embedding in features.items():
D, I = index.search(embedding.unsqueeze(0).cpu(), k=2) D, I = index.search(embedding.unsqueeze(0).cpu(), k=top_k+1)
sim = list(I[0][1:]) sim = list(I[0][1:])
print(f'{filename}: {sim}') print(f'{filename}: {sim}')
""" """
@ -180,10 +181,18 @@ def process(
embeddings = torch.stack( list( features.values() ) ) embeddings = torch.stack( list( features.values() ) )
sorted_similarities = {} sorted_similarities = {}
for filename in tqdm(keys, desc=f"Computing similarities: {speaker_path.name}"): for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"):
embedding = features[filename].unsqueeze(0) embedding = features[filename].unsqueeze(0)
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1)
# set current index to -inf
similarities[index] = float("-inf")
similarities = torch.topk(similarities, k=top_k, largest=True, sorted=True).indices.tolist()
# similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
sorted_similarities[filename] = similarities sorted_similarities[filename] = similarities
# sorting is slow, don't bother
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True) #sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
return sorted_similarities return sorted_similarities
@ -196,17 +205,19 @@ def main():
parser.add_argument("--yaml", type=Path) parser.add_argument("--yaml", type=Path)
parser.add_argument("--text", action="store_true") parser.add_argument("--text", action="store_true")
# dropped, because this might mess with the indices to map to
""" """
parser.add_argument("--trim-duration", type=float, default=3.0) parser.add_argument("--trim-duration", type=float, default=3.0)
parser.add_argument("--min-duration", type=float, default=0) parser.add_argument("--min-duration", type=float, default=0)
parser.add_argument("--max-duration", type=float, default=0) parser.add_argument("--max-duration", type=float, default=0)
""" """
parser.add_argument("--storage-backend", type=str, default="slop") parser.add_argument("--storage-backend", type=str, default="slop")
parser.add_argument("--top-k", type=int, default=8)
parser.add_argument("--audio-backend", type=str, default="encodec") parser.add_argument("--audio-backend", type=str, default="encodec")
parser.add_argument("--dtype", type=str, default="float32") parser.add_argument("--dtype", type=str, default="float16")
parser.add_argument("--amp", action="store_true") parser.add_argument("--amp", action="store_true")
parser.add_argument("--device", type=str, default="cpu") # unironically faster parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args() args = parser.parse_args()
@ -226,6 +237,7 @@ def main():
speaker_path=cfg.data_dir / speaker_name, speaker_path=cfg.data_dir / speaker_name,
yaml=args.yaml, yaml=args.yaml,
text=args.text, text=args.text,
top_k=args.top_k,
#trim_duration=args.trim_duration, #trim_duration=args.trim_duration,
#min_duration=args.min_duration, #min_duration=args.min_duration,
#max_duration=args.max_duration, #max_duration=args.max_duration,
@ -252,7 +264,7 @@ def main():
metadata[filename]["similar"] = sim metadata[filename]["similar"] = sim
with open(str(metadata_path), "w", encoding="utf-8") as f: with open(str(metadata_path), "wb") as f:
f.write( json.dumps( metadata ) ) f.write( json.dumps( metadata ) )
#f.write( truncate_json( json.dumps( metadata ) ) ) #f.write( truncate_json( json.dumps( metadata ) ) )
@ -273,6 +285,7 @@ def main():
speaker_path=args.input_speaker, speaker_path=args.input_speaker,
yaml=args.yaml, yaml=args.yaml,
text=args.text, text=args.text,
top_k=args.top_k,
#trim_duration=args.trim_duration, #trim_duration=args.trim_duration,
#min_duration=args.min_duration, #min_duration=args.min_duration,