This commit is contained in:
mrq 2024-09-17 22:44:36 -05:00
parent f00283440c
commit 6ceed866b5

View File

@ -4,7 +4,7 @@
"""
import os
import json
import orjson as json
import argparse
import torch
import torchaudio
@ -53,6 +53,7 @@ def process(
verbose=False,
metadata_path=None,
top_k=8,
trim_duration=0,
min_duration=0,
@ -167,7 +168,7 @@ def process(
# to-do: support just querying for list of similar to cram into JSON metadata
if verbose:
for filename, embedding in features.items():
D, I = index.search(embedding.unsqueeze(0).cpu(), k=2)
D, I = index.search(embedding.unsqueeze(0).cpu(), k=top_k+1)
sim = list(I[0][1:])
print(f'{filename}: {sim}')
"""
@ -180,10 +181,18 @@ def process(
embeddings = torch.stack( list( features.values() ) )
sorted_similarities = {}
for filename in tqdm(keys, desc=f"Computing similarities: {speaker_path.name}"):
for index, filename in tqdm(enumerate(keys), total=len(keys), desc=f"Computing similarities: {speaker_path.name}"):
embedding = features[filename].unsqueeze(0)
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1)
# set current index to -inf
similarities[index] = float("-inf")
similarities = torch.topk(similarities, k=top_k, largest=True, sorted=True).indices.tolist()
# similarities = torch.nn.functional.cosine_similarity(embedding, embeddings, dim=1).cpu().tolist()
sorted_similarities[filename] = similarities
# sorting is slow, don't bother
#sorted_similarities[filename] = sorted([ ( i if simplified_metadata else keys[i], similarity ) for i, similarity in enumerate( similarities ) if index != i ], key=lambda x: x[1], reverse=True)
return sorted_similarities
@ -196,17 +205,19 @@ def main():
parser.add_argument("--yaml", type=Path)
parser.add_argument("--text", action="store_true")
# dropped, because this might mess with the indices to map to
"""
parser.add_argument("--trim-duration", type=float, default=3.0)
parser.add_argument("--min-duration", type=float, default=0)
parser.add_argument("--max-duration", type=float, default=0)
"""
parser.add_argument("--storage-backend", type=str, default="slop")
parser.add_argument("--top-k", type=int, default=8)
parser.add_argument("--audio-backend", type=str, default="encodec")
parser.add_argument("--dtype", type=str, default="float32")
parser.add_argument("--dtype", type=str, default="float16")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--device", type=str, default="cpu") # unironically faster
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()
@ -226,6 +237,7 @@ def main():
speaker_path=cfg.data_dir / speaker_name,
yaml=args.yaml,
text=args.text,
top_k=args.top_k,
#trim_duration=args.trim_duration,
#min_duration=args.min_duration,
#max_duration=args.max_duration,
@ -252,7 +264,7 @@ def main():
metadata[filename]["similar"] = sim
with open(str(metadata_path), "w", encoding="utf-8") as f:
with open(str(metadata_path), "wb") as f:
f.write( json.dumps( metadata ) )
#f.write( truncate_json( json.dumps( metadata ) ) )
@ -273,6 +285,7 @@ def main():
speaker_path=args.input_speaker,
yaml=args.yaml,
text=args.text,
top_k=args.top_k,
#trim_duration=args.trim_duration,
#min_duration=args.min_duration,